{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import jieba\n",
    "import Spider\n",
    "\n",
    "def score(item, query):\n",
    "    score = 0\n",
    "    # TODO 对query查询的分词避免重复\n",
    "    for keyword in jieba.cut(query.lower()):\n",
    "        title_score = item[1].lower().count(keyword.lower())\n",
    "        content_score = item[2].lower().count(keyword.lower())\n",
    "        score += title_score * 5 + content_score * 3\n",
    "    return score\n",
    "\n",
    "\n",
    "class MySearcherC8V0:\n",
    "    \"\"\"\n",
    "    第七次课升级的搜索类版本：\n",
    "    1、__init__()初始化过程加载自定义分词词典\n",
    "    2、build_cache()改用jieba.cut_for_search进行分词\n",
    "    3、search()对查询分词\n",
    "    4、search()对分词结果取posting\n",
    "    5、search()对posting lists进行合并(交集)\n",
    "    6、build_cache()将posting保存格式改成只用doc_id(方便集合运算)\n",
    "    7、rank()实现对候选文档打分排序\n",
    "    8、score()实现对查询中包含的多词统计词频计分\n",
    "    \"\"\"\n",
    "    def __init__(self, scale: int=1):\n",
    "        self.docs = list()\n",
    "        self.load_data()\n",
    "        if scale > 1:\n",
    "            self.docs *= scale  # 文档规模倍增，用于测试搜索速度\n",
    "        self.cache = dict()\n",
    "        self.vocab = set()\n",
    "        self.lower_preprocess()\n",
    "        jieba.load_userdict('./dict.txt')\n",
    "        self.build_cache()\n",
    "\n",
    "    def load_data(self, data_file_name='./news_list.pkl'):\n",
    "        if os.path.exists(data_file_name):\n",
    "            self.docs = Spider.pickle_load(data_file_name)\n",
    "        else:\n",
    "            Spider.pickle_save(data_file_name)\n",
    "            self.docs = Spider.pickle_load(data_file_name)\n",
    "\n",
    "    def search(self, query):\n",
    "        result = None\n",
    "        for keyword in jieba.cut(query.lower()):\n",
    "            if keyword in self.cache:\n",
    "                if result is None:\n",
    "                    result = self.cache[keyword]\n",
    "                else:\n",
    "                    result = result & self.cache[keyword]\n",
    "            else:\n",
    "                result = set()\n",
    "                break\n",
    "        if result is None:\n",
    "            result = set()\n",
    "        sorted_result = self.rank(query, result)\n",
    "        return sorted_result\n",
    "\n",
    "    def rank(self, query, result_set):\n",
    "        result = list()\n",
    "        for doc_id in result_set:\n",
    "            result.append([doc_id, score(self.docs[doc_id], query)])\n",
    "\n",
    "        result.sort(key=lambda x: x[1], reverse=True)\n",
    "        return result\n",
    "\n",
    "    # def render_search_result(self, keyword):\n",
    "    #     count = 0\n",
    "    #     for item in self.search(keyword):\n",
    "    #         count += 1\n",
    "    #         print(f'{count}[{item[1]}] {highlight(self.docs[item[0]][1], keyword)}')\n",
    "\n",
    "    def build_cache(self):\n",
    "        \"\"\"用分词（用文档过滤词库）的方式初始化缓存（构建索引）\"\"\"\n",
    "        doc_id = 0\n",
    "        for doc in self.docs:\n",
    "            doc_word_set = set()\n",
    "            for word in jieba.cut_for_search(doc[3]):\n",
    "                if word not in doc_word_set:\n",
    "                    result_item = doc_id\n",
    "                    if word not in self.cache:\n",
    "                        self.cache[word] = {result_item}\n",
    "                    else:\n",
    "                        self.cache[word].add(result_item)\n",
    "                    self.vocab.add(word)\n",
    "                    doc_word_set.add(word)\n",
    "            doc_id += 1\n",
    "\n",
    "    def lower_preprocess(self):\n",
    "        for doc_id in range(len(self.docs)):\n",
    "            self.docs[doc_id].append(\n",
    "                (self.docs[doc_id][1] + ' ' + self.docs[doc_id][2]).lower())\n",
    "\n",
    "    def simple_test(self):\n",
    "        assert(len(self.search('tiktok')) > 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Building prefix dict from the default dictionary ...\n",
      "Loading model from cache C:\\Users\\10633\\AppData\\Local\\Temp\\jieba.cache\n",
      "Loading model cost 0.857 seconds.\n",
      "Prefix dict has been built successfully.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Wall time: 3.56 s\n"
     ]
    }
   ],
   "source": [
    "%time searcher_v0 = MySearcherC8V0()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "outputs": [],
   "source": [
    "def highlight(item, query: str, side_len: int = 12) -> str:\n",
    "    positions = list()\n",
    "    segments = list()\n",
    "    i = 0\n",
    "    content_lower = item[2].lower()\n",
    "    len_content_lower = len(content_lower)\n",
    "    for keyword in jieba.cut(query):\n",
    "        idx = content_lower.find(keyword.lower())\n",
    "        positions.append(idx)\n",
    "    positions.sort()\n",
    "    while i < len(positions):\n",
    "        start_pos = max(positions[i] - side_len, 0)\n",
    "        end_pos = min(positions[i] + side_len, len_content_lower)\n",
    "        while (i < len(positions) - 1) and (positions[i+1] - positions[i] < side_len*2):\n",
    "            end_pos = min(positions[i+1] + side_len, len_content_lower)\n",
    "            i += 1\n",
    "        start_ellipsis = '...' if start_pos > 0 else ''\n",
    "        end_ellipsis = '...' if end_pos < len_content_lower else ''\n",
    "        segments.append(start_ellipsis + item[2][start_pos: end_pos] + end_ellipsis)\n",
    "        i += 1\n",
    "    result = item[1] + '<br/>' + ''.join(segments)\n",
    "    # if idx >= 0:\n",
    "    #     ori_word = item[1][idx:idx+(len(query))]\n",
    "    #     result = item[1].replace(ori_word, f'<span style=\"color:red\";>{ori_word}</span>')\n",
    "    return result\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}